from pathlib import Path
from pdb import set_trace

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.ticker import MaxNLocator
from scipy import stats

from plotting_result_data import mean_confidence_interval, unpack_confidence_intervals

# Load the results CSV. The script later reads the columns 'use_input_expansion',
# 'samples used' and the metric columns selected below (presumably one row per run).
#   all_runs_data_5.csv: bidder 1 has a bid multiplier of 2, bidder 0 of 1.
#   all_runs_data_4.csv: used for the table that includes the upper limit metric.
df_filename = 'all_runs_data_5.csv'
df = pd.read_csv(df_filename)

# Metrics to plot, one line per entry. Alternatives:
# metrics = ['advertiser 0 participating value gain', 'advertiser 1 participating value gain']
# metrics = ['advertiser 0 payment zero bid offset', 'advertiser 1 payment zero bid offset']
# metrics = ['advertiser 0 utility gain zero bid offset', 'advertiser 1 utility gain zero bid offset']
# A (value column, payment column) tuple plots value - payment, i.e. the advertiser utility:
# metrics = [('advertiser 0 expected value', 'advertiser 0 payment zero bid offset'), ('advertiser 1 expected value', 'advertiser 1 payment zero bid offset')]
# Other metric columns: 'total advertiser participating value gain',
# 'total advertiser utility gain zero bid offset', 'total payment zero bid offset',
# 'total payment no offset', 'sequence log probability', 'reference LLM log probability'.
metrics = ['advertiser 0 expected value', 'advertiser 1 expected value']

use_input_expansion = True  # True: plot the runs that used the contextual LLM (input expansion)
colors = ['blue', 'orange', 'green']  # line colors, one per plotted metric

# Plot hyperparameters
# Y-axis label, chosen from the family of the selected metrics.
if 'advertiser 0 expected value' in metrics:
    metric_plot_name = 'Advertiser Value'
elif 'advertiser 0 payment zero bid offset' in metrics:
    metric_plot_name = 'Advertiser Payment'
elif 'advertiser 0 utility gain zero bid offset' in metrics:
    metric_plot_name = 'Advertiser Utility Gain'
elif 'advertiser 0 participating value gain' in metrics:
    metric_plot_name = 'Advertiser Reward Gain'
elif isinstance(metrics[0], tuple):
    # (value, payment) column pairs are plotted as their difference: the utility.
    metric_plot_name = 'Advertiser Utility'
else:
    # Fall back to the raw column name.
    metric_plot_name = metrics[0]


upper_limit_metric_name = 'reference LLM log probability'  # column drawn as the dashed reference line
font_size = 25   # axis-label and legend font size
tick_size = 18   # tick-label font size
save_plot = False  # True: write the figure to disk; False: show it interactively

no_gain_in_name = False  # set to True to remove the word 'Gain' from the y-axis label
# The dashed upper-limit line is only drawn when plotting a log-probability metric.
add_upper_limit = metrics[0] in ['sequence log probability', 'reference LLM log probability']


# def mean_confidence_interval(data, confidence=0.95):
#     mean = np.mean(data)
#     sem = stats.sem(data)
#     margin_of_error = sem * stats.t.ppf((1 + confidence) / 2., len(data)-1)
#     return mean, mean - margin_of_error, mean + margin_of_error

# # Unpack confidence intervals for plotting
# def unpack_confidence_intervals(confidence_intervals):
#     means = confidence_intervals.apply(lambda x: x[0])
#     lower_bounds = confidence_intervals.apply(lambda x: x[1])
#     upper_bounds = confidence_intervals.apply(lambda x: x[2])
#     return means, lower_bounds, upper_bounds


# Keep only the runs of the selected LLM variant and group them by candidate count.
df_to_use = df[df['use_input_expansion'] == use_input_expansion]
grouped_data = df_to_use.groupby('samples used')

plt.figure(figsize=(12, 8))

# One line (mean) plus shaded confidence band per metric.
# NOTE: `means` is intentionally left at module level; the upper-limit line reuses its index.
for i, metric in enumerate(metrics):
    if isinstance(metric, tuple):
        # Subtract the two columns per run first, then group, so the confidence
        # interval is computed on the per-run difference (value - payment).
        metric_values = df_to_use[metric[0]] - df_to_use[metric[1]]
        grouped_values = metric_values.groupby(df_to_use['samples used'])
    else:
        grouped_values = grouped_data[metric]

    ci = grouped_values.apply(mean_confidence_interval)
    means, lower, upper = unpack_confidence_intervals(ci)

    # Wrap around so more metrics than colors does not raise IndexError.
    color = colors[i % len(colors)]
    plt.plot(means.index, means, label=f'Advertiser {i}', color=color)
    plt.fill_between(means.index, lower, upper, color=color, alpha=0.2)


# Optional dashed reference ("upper limit") line for the log-probability plots.
if add_upper_limit:
    # NOTE(review): the reference LLM is described as having generated the sequences
    # only in the no-expansion runs, but grouped_data is filtered by
    # use_input_expansion -- confirm this is the intended subset.
    upper_limit_ci = grouped_data[upper_limit_metric_name].apply(mean_confidence_interval)
    upper_limit_means, upper_limit_lower, upper_limit_upper = unpack_confidence_intervals(upper_limit_ci)

    # Flat line at the value for one generated candidate ('samples used' label 1),
    # repeated across every x position of the metric curves so the shapes match.
    n_points = len(means.index)
    means_without_upper_limit = np.full(n_points, upper_limit_means.loc[1])
    lower_without_upper_limit = np.full(n_points, upper_limit_lower.loc[1])
    upper_without_upper_limit = np.full(n_points, upper_limit_upper.loc[1])

    plt.plot(means.index, means_without_upper_limit, label=r'Using $\hat{\pi}_{\text{opt}}(\cdot | x)$', color='green', linestyle='--')
    plt.fill_between(means.index, lower_without_upper_limit, upper_without_upper_limit, color='green', alpha=0.2)


plt.xlabel('Candidate Replies Generated', fontsize=font_size)

metric_name_to_show = metric_plot_name
if no_gain_in_name:
    # Remove the word 'Gain' and collapse the leftover whitespace, so no
    # trailing space shifts the centered y-axis label.
    metric_name_to_show = ' '.join(metric_plot_name.replace('Gain', '').split())

plt.ylabel(metric_name_to_show, fontsize=font_size)
# No title: the caption is provided in the paper.
plt.legend(fontsize=font_size)
plt.grid(True)

# Tick every candidate count 1..max_candidates explicitly (so 1 and the maximum
# are always labelled), with half a step of padding on each side of the x-axis.
max_candidates = 20
plt.xlim(0.5, max_candidates + 0.5)
plt.xticks(np.arange(1, max_candidates + 1, step=1))

# Larger tick labels for readability in the paper.
plt.tick_params(axis='both', which='major', labelsize=tick_size)


if save_plot:
    file_format = 'pdf'
    # Each results CSV corresponds to one bidder-multiplier configuration.
    save_folders = {
        'all_runs_data_4.csv': 'plots_bidder_multipliers_1_1',
        'all_runs_data_5.csv': 'plots_bidder_multipliers_1_2',
    }
    if df_filename not in save_folders:
        raise ValueError(f'No save folder configured for {df_filename!r}; add it to save_folders.')
    savefolder = save_folders[df_filename]

    # Utility is plotted from (value, payment) tuples; use readable names in the
    # file name instead of the tuple repr (without rebinding the global `metrics`).
    name_metrics = metrics
    if metrics[0] == ('advertiser 0 expected value', 'advertiser 0 payment zero bid offset'):
        name_metrics = ['advertiser 0 utility zero bid offset', 'advertiser 1 utility zero bid offset']

    upper_limit_tag = '_with_upper_limit' if add_upper_limit else ''
    metrics_tag = '_vs_'.join(str(m) for m in name_metrics)
    savename = f'./{savefolder}/{metrics_tag}_plot{upper_limit_tag}_contextual_llm_{use_input_expansion}.{file_format}'

    # File names use underscores instead of spaces.
    savename = savename.replace(' ', '_')
    # savefig does not create missing directories.
    Path(savefolder).mkdir(parents=True, exist_ok=True)
    plt.savefig(savename)
else:
    plt.show()
